import pandas as pd
import geopandas as gpd
import numpy as np
from rasterio.transform import Affine
from rasterio.features import geometry_mask
from datetime import datetime, timedelta
import sys

def idw_interpolation(x, y, values, xi, yi, power=2):
    """Interpolate scattered samples onto target points by inverse distance weighting.

    Each target value is ``sum(w_i * v_i) / sum(w_i)`` with ``w_i = 1 / d_i ** power``, where
    ``d_i`` is the Euclidean distance from sample ``i`` to the target.  A target that coincides
    exactly with one or more samples takes the (mean) value of those samples instead of
    dividing by zero.

    Args:
        x, y: 1-D sequences of sample coordinates (length n).
        values: 1-D sequence of sample values (length n).
        xi, yi: 1-D sequences of target coordinates (length m).
        power: exponent applied to the distance in the weights.

    Returns:
        1-D float array of length m with the interpolated values; all NaN if there are no samples.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    values = np.asarray(values, dtype=float)
    xi = np.asarray(xi, dtype=float)
    yi = np.asarray(yi, dtype=float)

    if x.size == 0:
        # No samples: nothing to interpolate from.
        return np.full(xi.shape, np.nan)

    # dist[i, j] is the distance from sample i to target j.
    dist = np.hypot(x[:, np.newaxis] - xi[np.newaxis, :], y[:, np.newaxis] - yi[np.newaxis, :])

    coincident = dist == 0
    with np.errstate(divide='ignore'):
        weights = 1.0 / dist ** power

    # For targets sitting exactly on a sample, the limit of IDW is that sample's value:
    # give coincident samples weight 1 and every other sample weight 0 in those columns.
    hit = coincident.any(axis=0)
    weights[:, hit] = coincident[:, hit]

    weights /= weights.sum(axis=0)
    return np.dot(values, weights)


# Interpolate one day's observations of a single variable and average them over each boundary
def spatial_join(boundary_gdf, year_gdf, filtered_column, year, month, day, latitude_column, longitude_column,
                 grid_size=100):
    """Interpolate one day's observations onto a grid and average them over each boundary polygon.

    Observations whose value is a sentinel (-999 or -996) or whose value/coordinates are missing
    are discarded.  The rest are interpolated with ``idw_interpolation`` onto a
    ``grid_size`` x ``grid_size`` lon/lat grid spanning the observations' bounding box.  For each
    geometry in *boundary_gdf*, the grid cells whose centres fall inside it are averaged.

    Args:
        boundary_gdf: GeoDataFrame of boundary polygons.  Modified in place (two columns are
            added) and also returned.
        year_gdf: the observations for a single day.
        filtered_column: name of the value column to interpolate.
        year, month, day: date components used for the log message and the ``Date`` column.
        latitude_column, longitude_column: coordinate column names in *year_gdf*.
        grid_size: number of grid nodes along each axis; must be at least 2.

    Returns:
        *boundary_gdf* with a ``MEAN_<filtered_column>`` column and a ``Date`` column
        (``YYYY-MM-DD`` string).  The mean is NaN for polygons containing no grid-cell centre,
        and for every polygon when there are no valid observations.

    Raises:
        ValueError: if *grid_size* is less than 2.
    """
    if grid_size < 2:
        raise ValueError('grid_size must be at least 2')

    date_label = f'{year}-{month:02d}-{day:02d}'
    print(f'Spatial Join Process Started for {date_label}')
    mean_column = f'MEAN_{filtered_column}'

    # Coerce to numbers first: a reader may return text columns (NOTE(review): e.g. GDAL's CSV
    # driver), and text would never equal the numeric sentinels.  Only the three columns that are
    # actually used decide whether a row is dropped.
    value_col = pd.to_numeric(year_gdf[filtered_column], errors='coerce')
    lon_col = pd.to_numeric(year_gdf[longitude_column], errors='coerce')
    lat_col = pd.to_numeric(year_gdf[latitude_column], errors='coerce')
    keep = value_col.notna() & lon_col.notna() & lat_col.notna() & ~value_col.isin([-999, -996])

    x = lon_col[keep].to_numpy(dtype=float)
    y = lat_col[keep].to_numpy(dtype=float)
    values = value_col[keep].to_numpy(dtype=float)

    if values.size == 0:
        # Nothing to interpolate; report NaN everywhere instead of failing on min()/max().
        boundary_gdf[mean_column] = np.nan
        boundary_gdf['Date'] = date_label
        return boundary_gdf

    # Define the grid over the bounding box of the observations.
    grid_longitude = np.linspace(x.min(), x.max(), num=grid_size)
    grid_latitude = np.linspace(y.min(), y.max(), num=grid_size)
    grid_x, grid_y = np.meshgrid(grid_longitude, grid_latitude)

    z_idw = idw_interpolation(x, y, values, grid_x.ravel(), grid_y.ravel()).reshape(grid_x.shape)

    # Treat each grid node as the centre of a raster cell.  The cell size is span / (grid_size - 1),
    # and the raster origin sits half a cell before the first node.  Row 0 is the minimum latitude,
    # hence the positive y scale.
    dx = (grid_longitude[-1] - grid_longitude[0]) / (grid_size - 1)
    dy = (grid_latitude[-1] - grid_latitude[0]) / (grid_size - 1)
    # NOTE(review): if all observations share a longitude or a latitude, dx or dy is 0 and the
    # transform is degenerate; rasterio will likely reject it -- confirm how that case should behave.
    transform = Affine.translation(grid_longitude[0] - dx / 2, grid_latitude[0] - dy / 2) * Affine.scale(dx, dy)

    mean_values = []
    for geom in boundary_gdf.geometry:
        inside = geometry_mask([geom], transform=transform, invert=True, out_shape=z_idw.shape)
        cells = z_idw[inside]
        mean_values.append(np.nanmean(cells) if cells.size > 0 else np.nan)

    boundary_gdf[mean_column] = mean_values
    boundary_gdf['Date'] = date_label
    return boundary_gdf


# Extract the CSV file data for a single day
def process_csv_file(df, date):
    """Return the rows of *df* whose ``Date`` column equals *date*.

    ``Date`` is converted with ``pd.to_datetime`` only when it is not already a datetime column.
    The conversion is applied to a copy, so the caller's frame is never modified.  The returned
    rows always carry a datetime ``Date`` column.

    Args:
        df: DataFrame with a ``Date`` column.
        date: the day to select (anything comparable to a pandas Timestamp, e.g. ``datetime``).

    Returns:
        The matching rows (possibly empty).
    """
    if not pd.api.types.is_datetime64_any_dtype(df['Date']):
        # assign() returns a new frame, avoiding a write into the caller's DataFrame.
        df = df.assign(Date=pd.to_datetime(df['Date']))
    return df[df['Date'] == date]


# Tidy the combined per-boundary results and write them to CSV and GeoJSON
def process_csv_spatial_file(gdf, filename, shp_filename, state_name='Oklahoma'):
    """Tidy the combined per-boundary results and write them to a CSV file and a GeoJSON file.

    A fixed set of administrative columns is dropped, and the identifying columns are renamed.
    (NOTE(review): the names look like Census TIGER/Line county attributes; confirm against the
    boundary file.)  A constant ``StateName`` column is added, and ``StateName``, ``StateFIPS``,
    ``CntyName``, ``CntyFIPS``, ``Centroid_Lat`` and ``Centroid_Lon`` are moved to the front.
    The geometry column is not dropped, so it is also written to the CSV.

    Args:
        gdf: GeoDataFrame containing the dropped/renamed columns listed below.
        filename: path of the CSV file to write.
        shp_filename: path of the output file; despite the name it is written as GeoJSON.
        state_name: value stored in the ``StateName`` column.

    Raises:
        KeyError: if any of the expected columns is missing from *gdf*.
    """
    # Keep the original geometry so it can be re-attached after the column shuffling.
    geometry = gdf.geometry

    gdf = gdf.drop(
        columns=['COUNTYFP', 'COUNTYNS', 'LSAD', 'CLASSFP', 'MTFCC', 'CSAFP', 'CBSAFP', 'METDIVFP',
                 'FUNCSTAT', 'ALAND', 'AWATER', 'NAMELSAD'])

    gdf = gdf.rename(columns={'GEOID': 'CntyFIPS', 'NAME': 'CntyName', 'STATEFP': 'StateFIPS',
                              'INTPTLAT': 'Centroid_Lat', 'INTPTLON': 'Centroid_Lon'})
    gdf['StateName'] = state_name

    leading_columns = ['StateName', 'StateFIPS', 'CntyName', 'CntyFIPS', 'Centroid_Lat', 'Centroid_Lon']
    gdf = gdf[leading_columns + [col for col in gdf.columns if col not in leading_columns]]
    gdf.to_csv(filename, index=False)

    # Rebuild a GeoDataFrame on the saved geometry, declared as EPSG:4269, and write GeoJSON.
    gdf = gpd.GeoDataFrame(gdf, geometry=geometry, crs='EPSG:4269')
    gdf.to_file(shp_filename, driver='GeoJSON')


# Run the daily IDW analysis for every year in the data and write the combined results to CSV and GeoJSON
def environmental_idw(csv_file_path, boundary_shape_file_path, filtered_column, latitude_column,
                      longitude_column):
    """Run the daily IDW analysis for every year in the data and write the combined results.

    Years from the ``YEAR`` column are processed in ascending order.  For every calendar day of
    each year, the observations dated that day are interpolated and averaged over every boundary
    polygon by ``spatial_join``.  Days without observations are reported and skipped.  All daily
    results are concatenated and written by ``process_csv_spatial_file`` to ``idw_data.csv`` and
    ``idw_shape_data.geojson``.  Nothing is written if no day had observations.

    Args:
        csv_file_path: observation file readable by ``geopandas.read_file`` with ``YEAR``,
            ``MONTH`` and ``DAY`` columns plus the value and coordinate columns named below.
        boundary_shape_file_path: boundary polygon file readable by ``geopandas.read_file``.
        filtered_column: name of the value column to interpolate.
        latitude_column, longitude_column: names of the coordinate columns.
    """
    env_csv_data = gpd.read_file(csv_file_path)  # Environmental observations
    boundary_data = gpd.read_file(boundary_shape_file_path)  # Boundary polygons

    # Build point geometries from the coordinate columns and put both layers in EPSG:4269.
    # NOTE(review): this *declares* the point coordinates to be EPSG:4269 (no reprojection);
    # confirm the source lon/lat really are NAD83.
    env_csv_data = gpd.GeoDataFrame(
        env_csv_data,
        geometry=gpd.points_from_xy(env_csv_data[longitude_column], env_csv_data[latitude_column]),
        crs='EPSG:4269',
    )
    boundary_data = boundary_data.to_crs('EPSG:4269')

    env_csv_data = env_csv_data.dropna(subset=['YEAR', 'MONTH', 'DAY'])
    env_csv_data['YEAR'] = env_csv_data['YEAR'].astype(int)
    env_csv_data['Date'] = pd.to_datetime(env_csv_data[['YEAR', 'MONTH', 'DAY']])

    # Split the observations by day once, instead of rescanning the whole frame for each day.
    rows_by_date = {date: rows for date, rows in env_csv_data.groupby('Date')}

    daily_results = []
    for year in sorted(int(y) for y in env_csv_data['YEAR'].unique()):
        for day in pd.date_range(pd.Timestamp(year, 1, 1), pd.Timestamp(year, 12, 31), freq='D'):
            label = day.strftime("%Y-%m-%d")
            daily_data = rows_by_date.get(day)
            if daily_data is None or daily_data.empty:
                print(f'No data found for {label}, continuing to next date...')
                continue

            print(f'Processing CSV data for {label}')
            daily_results.append(
                spatial_join(boundary_data.copy(), daily_data, filtered_column, year, day.month,
                             day.day, latitude_column, longitude_column))

    if not daily_results:
        # Without any result frame there is no geometry to write.
        print('No observations were processed; no output files written.')
        return

    # A single concat at the end avoids the quadratic cost of growing a frame inside the loop.
    yearly_data = pd.concat(daily_results)
    process_csv_spatial_file(yearly_data, 'idw_data.csv', 'idw_shape_data.geojson')




# Example usage
if __name__ == "__main__":
    # Command-line entry point:
    #   <script> CSV_FILE BOUNDARY_FILE TARGET_COLUMN LATITUDE_COLUMN LONGITUDE_COLUMN
    if len(sys.argv) < 6:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit(f'Usage: {sys.argv[0]} <csv_file> <boundary_file> <target_column> '
                 '<latitude_column> <longitude_column>')
    file_path, boundary_shp_file_path, target_column_name, latitude_column, longitude_column = sys.argv[1:6]
    print('arguments passed to the script are:', sys.argv)
    environmental_idw(file_path, boundary_shp_file_path, target_column_name, latitude_column, longitude_column)


